import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
from causally.model.utils import cal_wass

class DESCN(AbstractModel):
    """Shared-representation network with outcome, propensity and tau heads.

    A representation network feeds four single-output heads:

    * ``pred_layers_treated`` / ``pred_layers_control`` -- outcome predictions
      for the treated and control arms,
    * ``prop_layers`` -- a logit that is squashed by a sigmoid and regressed
      onto the treatment indicator,
    * ``tau_layers`` -- an additive offset between the two arms
      (``pred_0 + tau`` is fitted on treated rows, ``pred_1 - tau`` on
      control rows).

    The loss term names (ESTR/ESCR/CrossTR/CrossCR) presumably follow the
    DESCN paper -- confirm against the project documentation.
    """

    def __init__(self, config, dataset):
        """Read hyper-parameters from ``config`` and build all sub-networks.

        Config keys read here: ``alpha``, ``beta``, ``theta``, ``phi`` (loss
        weights), ``bn`` (batch-norm switch for the representation network)
        and ``repre_layer_sizes``, ``pred_layer_sizes``, ``prop_layer_sizes``,
        ``tau_layer_sizes`` (hidden sizes; the last entry of each list is the
        input width of the following layer).

        NOTE(review): ``self.config`` and ``self.dataset`` are assumed to be
        set by ``AbstractModel.__init__``; ``self.dataset.size[1]`` is used as
        the number of input features.
        """
        super().__init__(config, dataset)
        self.in_feature = self.dataset.size[1]

        # Weights of the four auxiliary loss terms (see ``calculate_loss``).
        self.alpha = self.config['alpha']
        self.beta = self.config['beta']
        self.theta = self.config['theta']
        self.phi = self.config['phi']

        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        self.pred_layer_sizes = self.config['pred_layer_sizes']
        self.prop_layer_sizes = self.config['prop_layer_sizes']
        self.tau_layer_sizes = self.config['tau_layer_sizes']

        # Optional input batch-norm followed by the shared representation MLP.
        input_norm = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        self.repre_layers = nn.Sequential(
            *(input_norm + get_linear_layers(self.in_feature, self.repre_layer_sizes,
                                             self.bn, nn.ReLU)))

        # Heads are created in the same order and with the same sub-module
        # names as before so state_dict keys and RNG consumption are unchanged.
        repre_dim = self.repre_layer_sizes[-1]
        self.pred_layers_treated = self._build_head(repre_dim, self.pred_layer_sizes, 'out1')
        self.pred_layers_control = self._build_head(repre_dim, self.pred_layer_sizes, 'out0')
        self.prop_layers = self._build_head(repre_dim, self.prop_layer_sizes, 'out_prop')
        self.tau_layers = self._build_head(repre_dim, self.tau_layer_sizes, 'out_tau')

        # Element-wise squared error; every loss term is summed explicitly.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _build_head(in_size, layer_sizes, out_name):
        """Return a ReLU MLP (no batch-norm) ending in a 1-unit linear layer named ``out_name``."""
        head = nn.Sequential(*get_linear_layers(in_size, layer_sizes, False, nn.ReLU))
        head.add_module(out_name, nn.Linear(layer_sizes[-1], 1))
        return head

    def _predict_potential_outcomes(self, x):
        """Run the representation network once and both outcome heads.

        Caches the results on ``self.repre``, ``self.out_1`` and ``self.out_0``
        and returns ``(out_1, out_0)``.
        """
        self.repre = self.repre_layers(x)
        self.out_1 = self.pred_layers_treated(self.repre)
        self.out_0 = self.pred_layers_control(self.repre)
        return self.out_1, self.out_0

    def forward(self, x, t):
        """Return, per row, the outcome prediction of the arm selected by ``t``.

        Args:
            x: Input features, passed to the representation network.
            t: Treatment indicator; rows equal to 1 take the treated head,
                all other rows take the control head. It is reshaped to a
                column so both ``(N,)`` and ``(N, 1)`` inputs are accepted.

        Returns:
            Tensor with the shape of the head outputs (last dimension 1).
        """
        out_1, out_0 = self._predict_potential_outcomes(x)
        return torch.where(t.reshape(-1, 1) == 1, out_1, out_0)

    def get_repre(self, x, device):
        """Return the representation of ``x`` computed on ``device`` without gradients.

        Side effects kept from the original implementation: the whole model is
        switched to eval mode and is NOT switched back, and ``repre_layers`` is
        moved to ``device`` in place.
        """
        self.eval()
        with torch.no_grad():
            return self.repre_layers.to(device)(x.to(device))

    def calculate_loss(self, x, t, y, w):
        """Compute the batch-averaged multi-task training loss.

        The loss is ``(1 / N) * (L_pi + alpha * L_ESTR + beta * L_ESCR +
        theta * L_CrossTR + phi * L_CrossCR)`` where every term is a summed
        squared error:

        * ``L_pi``: propensity vs. ``t``
        * ``L_ESTR``: ``prop * pred_1`` vs. ``t * y``
        * ``L_ESCR``: ``(1 - prop) * pred_0`` vs. ``(1 - t) * y``
        * ``L_CrossTR``: ``t * (tau + pred_0)`` vs. ``t * y``
        * ``L_CrossCR``: ``(1 - t) * (pred_1 - tau)`` vs. ``(1 - t) * y``

        Squared error is used regardless of ``self.loss_type``.

        Args:
            x: Input features.
            t: Treatment indicator, ``(N,)`` or ``(N, 1)``.
            y: Observed outcome, ``(N,)`` or ``(N, 1)``.
            w: Unused; kept for interface compatibility.

        Returns:
            Scalar loss tensor.
        """
        # One pass through the shared representation serves all four heads.
        # Previously the representation network ran three times per batch,
        # which also updated batch-norm running statistics three times.
        pred_1, pred_0 = self._predict_potential_outcomes(x)
        prop = self.sigmoid(self.prop_layers(self.repre)).float()
        tau = self.tau_layers(self.repre).float()
        pred_1 = pred_1.float()
        pred_0 = pred_0.float()

        # Force column vectors so element-wise products with the (N, 1) head
        # outputs cannot silently broadcast to (N, N).
        t = t.float().reshape(-1, 1)
        y = y.float().reshape(-1, 1)
        not_t = 1.0 - t

        treated_target = t * y
        control_target = not_t * y

        l_pi = torch.sum(self.mse_loss(prop, t))
        l_estr = torch.sum(self.mse_loss(prop * pred_1, treated_target))
        l_escr = torch.sum(self.mse_loss((1.0 - prop) * pred_0, control_target))
        l_cross_tr = torch.sum(self.mse_loss(t * (tau + pred_0), treated_target))
        l_cross_cr = torch.sum(self.mse_loss(not_t * (-tau + pred_1), control_target))

        loss = (1. / t.shape[0]) * (l_pi + self.alpha * l_estr + self.beta * l_escr
                                    + self.theta * l_cross_tr + self.phi * l_cross_cr)
        return loss

    def predict(self, x, t):
        """Predict the outcome of each row under its treatment assignment ``t``.

        Args:
            x: Input features.
            t: Treatment indicator, see ``forward``.

        Returns:
            The raw ``forward`` output when ``self.loss_type == 'MSE'``,
            otherwise its element-wise sigmoid.
        """
        y = self.forward(x, t)
        if self.loss_type == 'MSE':
            return y
        return torch.sigmoid(y)